{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "text_classification_rnn.ipynb",
      "provenance": [],
      "private_outputs": true,
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.4"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "hX4n9TsbGw-f"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "colab_type": "code",
        "id": "0nbI5DtDGw-i",
        "colab": {}
      },
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "9TnJztDZGw-n"
      },
      "source": [
        "# RNN を使ったテキスト分類"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "AfN3bMR5Gw-o"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/tutorials/text/text_classification_rnn\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/ja/tutorials/text/text_classification_rnn.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/ja/tutorials/text/text_classification_rnn.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/ja/tutorials/text/text_classification_rnn.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "tudzcncJXetB"
      },
      "source": [
        "Note: これらのドキュメントは私たちTensorFlowコミュニティが翻訳したものです。コミュニティによる 翻訳は**ベストエフォート**であるため、この翻訳が正確であることや[英語の公式ドキュメント](https://www.tensorflow.org/?hl=en)の 最新の状態を反映したものであることを保証することはできません。 この翻訳の品質を向上させるためのご意見をお持ちの方は、GitHubリポジトリ[tensorflow/docs](https://github.com/tensorflow/docs)にプルリクエストをお送りください。 コミュニティによる翻訳やレビューに参加していただける方は、 [docs-ja@tensorflow.org メーリングリスト](https://groups.google.com/a/tensorflow.org/forum/#!forum/docs-ja)にご連絡ください。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "lUWearf0Gw-p"
      },
      "source": [
        "このテキスト分類チュートリアルでは、感情分析のために [IMDB 映画レビュー大型データセット](http://ai.stanford.edu/~amaas/data/sentiment/) を使って [リカレントニューラルネットワーク](https://developers.google.com/machine-learning/glossary/#recurrent_neural_network) を訓練します。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "_2VQo4bajwUU"
      },
      "source": [
        "## 設定"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "z682XYsrjkY9",
        "colab": {}
      },
      "source": [
        "from __future__ import absolute_import, division, print_function, unicode_literals\n",
        "\n",
        "try:\n",
        "  # %tensorflow_version は Colab 上でだけ使用可能\n",
        "  %tensorflow_version 2.x\n",
        "except Exception:\n",
        "  pass\n",
        "import tensorflow_datasets as tfds\n",
        "import tensorflow as tf"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "1rXHa-w9JZhb"
      },
      "source": [
        "`matplotlib` をインポートしグラフを描画するためのヘルパー関数を作成します。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Mp1Z7P9pYRSK",
        "colab": {}
      },
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "def plot_graphs(history, string):\n",
        "  plt.plot(history.history[string])\n",
        "  plt.plot(history.history['val_'+string], '')\n",
        "  plt.xlabel(\"Epochs\")\n",
        "  plt.ylabel(string)\n",
        "  plt.legend([string, 'val_'+string])\n",
        "  plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "pRmMubr0jrE2"
      },
      "source": [
        "## 入力パイプラインの設定\n",
        "\n",
        "IMDB 映画レビュー大型データセットは*二値分類*データセットです。すべてのレビューは、*好意的(positive)* または *非好意的(negative)* のいずれかの感情を含んでいます。\n",
        "\n",
        "[TFDS](https://www.tensorflow.org/datasets) を使ってこのデータセットをダウンロードします。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "SHRwRoP2nVHX",
        "colab": {}
      },
      "source": [
        "dataset, info = tfds.load('imdb_reviews/subwords8k', with_info=True,\n",
        "                          as_supervised=True)\n",
        "train_dataset, test_dataset = dataset['train'], dataset['test']"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "MCorLciXSDJE"
      },
      "source": [
        " このデータセットの `info` には、エンコーダー(`tfds.features.text.SubwordTextEncoder`) が含まれています。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "EplYp5pNnW1S",
        "colab": {}
      },
      "source": [
        "encoder = info.features['text'].encoder"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "e7ACuHM5hFp3",
        "colab": {}
      },
      "source": [
        "print ('Vocabulary size: {}'.format(encoder.vocab_size))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "tAfGg8YRe6fu"
      },
      "source": [
        "このテキストエンコーダーは、任意の文字列を可逆的にエンコードします。必要であればバイトエンコーディングにフォールバックします。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Bq6xDmf2SAs-",
        "colab": {}
      },
      "source": [
        "sample_string = 'Hello TensorFlow.'\n",
        "\n",
        "encoded_string = encoder.encode(sample_string)\n",
        "print ('Encoded string is {}'.format(encoded_string))\n",
        "\n",
        "original_string = encoder.decode(encoded_string)\n",
        "print ('The original string: \"{}\"'.format(original_string))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "TN7QbKaM4-5H",
        "colab": {}
      },
      "source": [
        "assert original_string == sample_string"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "MDVc6UGO5Dh6",
        "colab": {}
      },
      "source": [
        "for index in encoded_string:\n",
        "  print ('{} ----> {}'.format(index, encoder.decode([index])))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "GlYWqhTVlUyQ"
      },
      "source": [
        "## 訓練用データの準備"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "z2qVJzcEluH_"
      },
      "source": [
        "次に、これらのエンコード済み文字列をバッチ化します。`padded_batch` メソッドを使ってバッチ中の一番長い文字列の長さにゼロパディングを行います。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "dDsCaZCDYZgm",
        "colab": {}
      },
      "source": [
        "BUFFER_SIZE = 10000\n",
        "BATCH_SIZE = 64"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "VznrltNOnUc5",
        "colab": {}
      },
      "source": [
        "train_dataset = train_dataset.shuffle(BUFFER_SIZE)\n",
        "train_dataset = train_dataset.padded_batch(BATCH_SIZE, train_dataset.output_shapes)\n",
        "\n",
        "test_dataset = test_dataset.padded_batch(BATCH_SIZE, test_dataset.output_shapes)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "bjUqGVBxGw-t"
      },
      "source": [
        "## モデルの作成"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "bgs6nnSTGw-t"
      },
      "source": [
        "`tf.keras.Sequential` モデルを構築しましょう。最初に Embedding レイヤーから始めます。Embedding レイヤーは単語一つに対して一つのベクトルを収容します。呼び出しを受けると、Embedding レイヤーは単語のインデックスのシーケンスを、ベクトルのシーケンスに変換します。これらのベクトルは訓練可能です。（十分なデータで）訓練されたあとは、おなじような意味をもつ単語は、しばしばおなじようなベクトルになります。\n",
        "\n",
        "このインデックス参照は、ワンホットベクトルを `tf.keras.layers.Dense` レイヤーを使って行うおなじような演算に比べてずっと効率的です。\n",
        "\n",
        "リカレントニューラルネットワーク（RNN）は、シーケンスの入力を要素を一つずつ扱うことで処理します。RNN は、あるタイムステップでの出力を次のタイムステップの入力へと、次々に渡していきます。\n",
        "\n",
        "RNN レイヤーとともに、`tf.keras.layers.Bidirectional` ラッパーを使用することができます。このラッパーは、入力を RNN 層の順方向と逆方向に伝え、その後出力を結合します。これにより、RNN は長期的な依存関係を学習できます。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "LwfoBkmRYcP3",
        "colab": {}
      },
      "source": [
        "model = tf.keras.Sequential([\n",
        "    tf.keras.layers.Embedding(encoder.vocab_size, 64),\n",
        "    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64)),\n",
        "    tf.keras.layers.Dense(64, activation='relu'),\n",
        "    tf.keras.layers.Dense(1, activation='sigmoid')\n",
        "])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "sRI776ZcH3Tf"
      },
      "source": [
        "訓練プロセスを定義するため、Keras モデルをコンパイルします。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "kj2xei41YZjC",
        "colab": {}
      },
      "source": [
        "model.compile(loss='binary_crossentropy',\n",
        "              optimizer=tf.keras.optimizers.Adam(1e-4),\n",
        "              metrics=['accuracy'])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "zIwH3nto596k"
      },
      "source": [
        "## モデルの訓練"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "hw86wWS4YgR2",
        "colab": {}
      },
      "source": [
        "history = model.fit(train_dataset, epochs=10,\n",
        "                    validation_data=test_dataset, \n",
        "                    validation_steps=30)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "BaNbXi43YgUT",
        "colab": {}
      },
      "source": [
        "test_loss, test_acc = model.evaluate(test_dataset)\n",
        "\n",
        "print('Test Loss: {}'.format(test_loss))\n",
        "print('Test Accuracy: {}'.format(test_acc))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "DwSE_386uhxD"
      },
      "source": [
        "上記のモデルはシーケンスに適用されたパディングをマスクしていません。パディングされたシーケンスで訓練を行い、パディングをしていないシーケンスでテストするとすれば、このことが結果を歪める可能性があります。理想的にはこれを避けるために、 [マスキング](../../guide/keras/masking_and_padding)を使うべきですが、下記のように出力への影響は小さいものでしかありません。 \n",
        "\n",
        "予測値が 0.5 以上であればポジティブ、それ以外はネガティブです。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "8w0dseJMiEUh",
        "colab": {}
      },
      "source": [
        "def pad_to_size(vec, size):\n",
        "  zeros = [0] * (size - len(vec))\n",
        "  vec.extend(zeros)\n",
        "  return vec"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Y-E4cgkIvmVu",
        "colab": {}
      },
      "source": [
        "def sample_predict(sentence, pad):\n",
        "  encoded_sample_pred_text = encoder.encode(sample_pred_text)\n",
        "\n",
        "  if pad:\n",
        "    encoded_sample_pred_text = pad_to_size(encoded_sample_pred_text, 64)\n",
        "  encoded_sample_pred_text = tf.cast(encoded_sample_pred_text, tf.float32)\n",
        "  predictions = model.predict(tf.expand_dims(encoded_sample_pred_text, 0))\n",
        "\n",
        "  return (predictions)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "O41gw3KfWHus",
        "colab": {}
      },
      "source": [
        "# パディングなしのサンプルテキストの推論\n",
        "\n",
        "sample_pred_text = ('The movie was cool. The animation and the graphics '\n",
        "                    'were out of this world. I would recommend this movie.')\n",
        "predictions = sample_predict(sample_pred_text, pad=False)\n",
        "print (predictions)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "kFh4xLARucTy",
        "colab": {}
      },
      "source": [
        "# パディングありのサンプルテキストの推論\n",
        "\n",
        "sample_pred_text = ('The movie was cool. The animation and the graphics '\n",
        "                    'were out of this world. I would recommend this movie.')\n",
        "predictions = sample_predict(sample_pred_text, pad=True)\n",
        "print (predictions)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "ZfIVoxiNmKBF",
        "colab": {}
      },
      "source": [
        "plot_graphs(history, 'accuracy')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "IUzgkqnhmKD2",
        "colab": {}
      },
      "source": [
        "plot_graphs(history, 'loss')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "7g1evcaRpTKm"
      },
      "source": [
        "## 2つ以上の LSTM レイヤーを重ねる\n",
        "\n",
        "Keras のリカレントレイヤーには、コンストラクタの `return_sequences` 引数でコントロールされる2つのモードがあります。\n",
        "\n",
        "* それぞれのタイムステップの連続した出力のシーケンス全体（shape が `(batch_size, timesteps, output_features)` の3階テンソル）を返す。\n",
        "* それぞれの入力シーケンスの最後の出力だけ（shape が `(batch_size, output_features)` の2階テンソル）を返す。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "jo1jjO3vn0jo",
        "colab": {}
      },
      "source": [
        "model = tf.keras.Sequential([\n",
        "    tf.keras.layers.Embedding(encoder.vocab_size, 64),\n",
        "    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64,  return_sequences=True)),\n",
        "    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32)),\n",
        "    tf.keras.layers.Dense(64, activation='relu'),\n",
        "    tf.keras.layers.Dropout(0.5),\n",
        "    tf.keras.layers.Dense(1, activation='sigmoid')\n",
        "])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "hEPV5jVGp-is",
        "colab": {}
      },
      "source": [
        "model.compile(loss='binary_crossentropy',\n",
        "              optimizer=tf.keras.optimizers.Adam(1e-4),\n",
        "              metrics=['accuracy'])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "LeSE-YjdqAeN",
        "colab": {}
      },
      "source": [
        "history = model.fit(train_dataset, epochs=10,\n",
        "                    validation_data=test_dataset,\n",
        "                    validation_steps=30)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "_LdwilM1qPM3",
        "colab": {}
      },
      "source": [
        "test_loss, test_acc = model.evaluate(test_dataset)\n",
        "\n",
        "print('Test Loss: {}'.format(test_loss))\n",
        "print('Test Accuracy: {}'.format(test_acc))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "ykUKnAoqbycW",
        "colab": {}
      },
      "source": [
        "# パディングなしのサンプルテキストの推論\n",
        "\n",
        "sample_pred_text = ('The movie was not good. The animation and the graphics '\n",
        "                    'were terrible. I would not recommend this movie.')\n",
        "predictions = sample_predict(sample_pred_text, pad=False)\n",
        "print (predictions)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "2RiC-94zvdZO",
        "colab": {}
      },
      "source": [
        "# パディングありのサンプルテキストの推論\n",
        "\n",
        "sample_pred_text = ('The movie was not good. The animation and the graphics '\n",
        "                    'were terrible. I would not recommend this movie.')\n",
        "predictions = sample_predict(sample_pred_text, pad=True)\n",
        "print (predictions)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "_YYub0EDtwCu",
        "colab": {}
      },
      "source": [
        "plot_graphs(history, 'accuracy')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "DPV3Nn9xtwFM",
        "colab": {}
      },
      "source": [
        "plot_graphs(history, 'loss')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "9xvpE3BaGw_V"
      },
      "source": [
        "[GRU レイヤー](https://www.tensorflow.org/api_docs/python/tf/keras/layers/GRU)など既存のほかのレイヤーを調べてみましょう。\n",
        "\n",
        "カスタム RNN の構築に興味があるのであれば、[Keras RNN ガイド](../../guide/keras/rnn.ipynb) を参照してください。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "UQK2Na3Z_aMx",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}